/*!
Generating SPIR-V for ray query operations.
*/

use alloc::{vec, vec::Vec};

use super::{
    Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType,
    Writer,
};
use crate::{arena::Handle, back::spv::LookupRayQueryFunction};

/// helper function to check if a particular flag is set in a u32.
fn write_ray_flags_contains_flags(
    writer: &mut Writer,
    block: &mut Block,
    id: spirv::Word,
    flag: u32,
) -> spirv::Word {
    let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag));
    let zero_id = writer.get_constant_scalar(crate::Literal::U32(0));
    let u32_type_id = writer.get_u32_type_id();
    let bool_ty = writer.get_bool_type_id();

    let and_id = writer.id_gen.next();
    block.body.push(Instruction::binary(
        spirv::Op::BitwiseAnd,
        u32_type_id,
        and_id,
        id,
        bit_id,
    ));

    let eq_id = writer.id_gen.next();
    block.body.push(Instruction::binary(
        spirv::Op::INotEqual,
        bool_ty,
        eq_id,
        and_id,
        zero_id,
    ));

    eq_id
}

impl Writer {
    /// writes a logical and of two scalar booleans
    fn write_logical_and(
        &mut self,
        block: &mut Block,
        one: spirv::Word,
        two: spirv::Word,
    ) -> spirv::Word {
        let id = self.id_gen.next();
        let bool_id = self.get_bool_type_id();
        block.body.push(Instruction::binary(
            spirv::Op::LogicalAnd,
            bool_id,
            id,
            one,
            two,
        ));
        id
    }

    fn write_reduce_and(&mut self, block: &mut Block, mut bools: Vec<spirv::Word>) -> spirv::Word {
        // The combined `and`ed together of all of the bools up to this point.
        let mut current_combined = bools.pop().unwrap();
        for boolean in bools {
            current_combined = self.write_logical_and(block, current_combined, boolean)
        }
        current_combined
    }

    // returns the id of the function, the function, and ids for its arguments.
    fn write_function_signature(
        &mut self,
        arg_types: &[spirv::Word],
        return_ty: spirv::Word,
    ) -> (spirv::Word, Function, Vec<spirv::Word>) {
        let func_ty = self.get_function_type(LookupFunctionType {
            parameter_type_ids: Vec::from(arg_types),
            return_type_id: return_ty,
        });

        let mut function = Function::default();
        let func_id = self.id_gen.next();
        function.signature = Some(Instruction::function(
            return_ty,
            func_id,
            spirv::FunctionControl::empty(),
            func_ty,
        ));

        let mut arg_ids = Vec::with_capacity(arg_types.len());

        for (idx, &arg_ty) in arg_types.iter().enumerate() {
            let id = self.id_gen.next();
            let instruction = Instruction::function_parameter(arg_ty, id);
            function.parameters.push(FunctionArgument {
                instruction,
                handle_id: idx as u32,
            });
            arg_ids.push(id);
        }
        (func_id, function, arg_ids)
    }

    pub(super) fn write_ray_query_get_intersection_function(
        &mut self,
        is_committed: bool,
        ir_module: &crate::Module,
    ) -> spirv::Word {
        if let Some(&word) =
            self.ray_query_functions
                .get(&LookupRayQueryFunction::GetIntersection {
                    committed: is_committed,
                })
        {
            return word;
        }
        let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
        let intersection_type_id = self.get_handle_type_id(ray_intersection);
        let intersection_pointer_type_id =
            self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);

        let flag_type_id = self.get_u32_type_id();
        let flag_pointer_type_id =
            self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);

        let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
            columns: crate::VectorSize::Quad,
            rows: crate::VectorSize::Tri,
            scalar: crate::Scalar::F32,
        });
        let transform_pointer_type_id =
            self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);

        let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
            size: crate::VectorSize::Bi,
            scalar: crate::Scalar::F32,
        });
        let barycentrics_pointer_type_id =
            self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);

        let bool_type_id = self.get_bool_type_id();
        let bool_pointer_type_id =
            self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);

        let scalar_type_id = self.get_f32_type_id();
        let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);

        let argument_type_id = self.get_ray_query_pointer_id();

        let (func_id, mut function, arg_ids) = self.write_function_signature(
            &[argument_type_id, flag_pointer_type_id],
            intersection_type_id,
        );

        let query_id = arg_ids[0];
        let intersection_tracker_id = arg_ids[1];

        let label_id = self.id_gen.next();
        let mut block = Block::new(label_id);

        let blank_intersection = self.get_constant_null(intersection_type_id);
        let blank_intersection_id = self.id_gen.next();
        // This must be before everything else in the function.
        block.body.push(Instruction::variable(
            intersection_pointer_type_id,
            blank_intersection_id,
            spirv::StorageClass::Function,
            Some(blank_intersection),
        ));

        let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
        } else {
            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
        } as _));

        let loaded_ray_query_tracker_id = self.id_gen.next();
        block.body.push(Instruction::load(
            flag_type_id,
            loaded_ray_query_tracker_id,
            intersection_tracker_id,
            None,
        ));
        let proceeded_id = write_ray_flags_contains_flags(
            self,
            &mut block,
            loaded_ray_query_tracker_id,
            super::RayQueryPoint::PROCEED.bits(),
        );
        let finished_proceed_id = write_ray_flags_contains_flags(
            self,
            &mut block,
            loaded_ray_query_tracker_id,
            super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
        );
        let proceed_finished_correct_id = if is_committed {
            finished_proceed_id
        } else {
            let not_finished_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::LogicalNot,
                bool_type_id,
                not_finished_id,
                finished_proceed_id,
            ));
            not_finished_id
        };

        let is_valid_id =
            self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id);

        let valid_id = self.id_gen.next();
        let mut valid_block = Block::new(valid_id);

        let final_label_id = self.id_gen.next();
        let mut final_block = Block::new(final_label_id);

        block.body.push(Instruction::selection_merge(
            final_label_id,
            spirv::SelectionControl::NONE,
        ));
        function.consume(
            block,
            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id),
        );

        let raw_kind_id = self.id_gen.next();
        valid_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionTypeKHR,
                flag_type_id,
                raw_kind_id,
                query_id,
                intersection_id,
            ));
        let kind_id = if is_committed {
            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
            raw_kind_id
        } else {
            // Remap from the candidate kind to IR
            let condition_id = self.id_gen.next();
            let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
                    as _,
            ));
            valid_block.body.push(Instruction::binary(
                spirv::Op::IEqual,
                self.get_bool_type_id(),
                condition_id,
                raw_kind_id,
                committed_triangle_kind_id,
            ));
            let kind_id = self.id_gen.next();
            valid_block.body.push(Instruction::select(
                flag_type_id,
                kind_id,
                condition_id,
                self.get_constant_scalar(crate::Literal::U32(
                    crate::RayQueryIntersection::Triangle as _,
                )),
                self.get_constant_scalar(crate::Literal::U32(
                    crate::RayQueryIntersection::Aabb as _,
                )),
            ));
            kind_id
        };
        let idx_id = self.get_index_constant(0);
        let access_idx = self.id_gen.next();
        valid_block.body.push(Instruction::access_chain(
            flag_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        valid_block
            .body
            .push(Instruction::store(access_idx, kind_id, None));

        let not_none_comp_id = self.id_gen.next();
        let none_id =
            self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
        valid_block.body.push(Instruction::binary(
            spirv::Op::INotEqual,
            self.get_bool_type_id(),
            not_none_comp_id,
            kind_id,
            none_id,
        ));

        let not_none_label_id = self.id_gen.next();
        let mut not_none_block = Block::new(not_none_label_id);

        let outer_merge_label_id = self.id_gen.next();
        let outer_merge_block = Block::new(outer_merge_label_id);

        valid_block.body.push(Instruction::selection_merge(
            outer_merge_label_id,
            spirv::SelectionControl::NONE,
        ));
        function.consume(
            valid_block,
            Instruction::branch_conditional(
                not_none_comp_id,
                not_none_label_id,
                outer_merge_label_id,
            ),
        );

        let instance_custom_index_id = self.id_gen.next();
        not_none_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
                flag_type_id,
                instance_custom_index_id,
                query_id,
                intersection_id,
            ));
        let instance_id = self.id_gen.next();
        not_none_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
                flag_type_id,
                instance_id,
                query_id,
                intersection_id,
            ));
        let sbt_record_offset_id = self.id_gen.next();
        not_none_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
                flag_type_id,
                sbt_record_offset_id,
                query_id,
                intersection_id,
            ));
        let geometry_index_id = self.id_gen.next();
        not_none_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
                flag_type_id,
                geometry_index_id,
                query_id,
                intersection_id,
            ));
        let primitive_index_id = self.id_gen.next();
        not_none_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
                flag_type_id,
                primitive_index_id,
                query_id,
                intersection_id,
            ));

        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
        // but it's not a property of an intersection.

        let object_to_world_id = self.id_gen.next();
        not_none_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
                transform_type_id,
                object_to_world_id,
                query_id,
                intersection_id,
            ));
        let world_to_object_id = self.id_gen.next();
        not_none_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
                transform_type_id,
                world_to_object_id,
                query_id,
                intersection_id,
            ));

        // instance custom index
        let idx_id = self.get_index_constant(2);
        let access_idx = self.id_gen.next();
        not_none_block.body.push(Instruction::access_chain(
            flag_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        not_none_block.body.push(Instruction::store(
            access_idx,
            instance_custom_index_id,
            None,
        ));

        // instance
        let idx_id = self.get_index_constant(3);
        let access_idx = self.id_gen.next();
        not_none_block.body.push(Instruction::access_chain(
            flag_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        not_none_block
            .body
            .push(Instruction::store(access_idx, instance_id, None));

        let idx_id = self.get_index_constant(4);
        let access_idx = self.id_gen.next();
        not_none_block.body.push(Instruction::access_chain(
            flag_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        not_none_block
            .body
            .push(Instruction::store(access_idx, sbt_record_offset_id, None));

        let idx_id = self.get_index_constant(5);
        let access_idx = self.id_gen.next();
        not_none_block.body.push(Instruction::access_chain(
            flag_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        not_none_block
            .body
            .push(Instruction::store(access_idx, geometry_index_id, None));

        let idx_id = self.get_index_constant(6);
        let access_idx = self.id_gen.next();
        not_none_block.body.push(Instruction::access_chain(
            flag_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        not_none_block
            .body
            .push(Instruction::store(access_idx, primitive_index_id, None));

        let idx_id = self.get_index_constant(9);
        let access_idx = self.id_gen.next();
        not_none_block.body.push(Instruction::access_chain(
            transform_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        not_none_block
            .body
            .push(Instruction::store(access_idx, object_to_world_id, None));

        let idx_id = self.get_index_constant(10);
        let access_idx = self.id_gen.next();
        not_none_block.body.push(Instruction::access_chain(
            transform_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        not_none_block
            .body
            .push(Instruction::store(access_idx, world_to_object_id, None));

        let tri_comp_id = self.id_gen.next();
        let tri_id = self.get_constant_scalar(crate::Literal::U32(
            crate::RayQueryIntersection::Triangle as _,
        ));
        not_none_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            self.get_bool_type_id(),
            tri_comp_id,
            kind_id,
            tri_id,
        ));

        let tri_label_id = self.id_gen.next();
        let mut tri_block = Block::new(tri_label_id);

        let merge_label_id = self.id_gen.next();
        let merge_block = Block::new(merge_label_id);
        // t
        {
            let block = if is_committed {
                &mut not_none_block
            } else {
                &mut tri_block
            };
            let t_id = self.id_gen.next();
            block.body.push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionTKHR,
                scalar_type_id,
                t_id,
                query_id,
                intersection_id,
            ));
            let idx_id = self.get_index_constant(1);
            let access_idx = self.id_gen.next();
            block.body.push(Instruction::access_chain(
                float_pointer_type_id,
                access_idx,
                blank_intersection_id,
                &[idx_id],
            ));
            block.body.push(Instruction::store(access_idx, t_id, None));
        }
        not_none_block.body.push(Instruction::selection_merge(
            merge_label_id,
            spirv::SelectionControl::NONE,
        ));
        function.consume(
            not_none_block,
            Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
        );

        let barycentrics_id = self.id_gen.next();
        tri_block.body.push(Instruction::ray_query_get_intersection(
            spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
            barycentrics_type_id,
            barycentrics_id,
            query_id,
            intersection_id,
        ));

        let front_face_id = self.id_gen.next();
        tri_block.body.push(Instruction::ray_query_get_intersection(
            spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
            bool_type_id,
            front_face_id,
            query_id,
            intersection_id,
        ));

        let idx_id = self.get_index_constant(7);
        let access_idx = self.id_gen.next();
        tri_block.body.push(Instruction::access_chain(
            barycentrics_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        tri_block
            .body
            .push(Instruction::store(access_idx, barycentrics_id, None));

        let idx_id = self.get_index_constant(8);
        let access_idx = self.id_gen.next();
        tri_block.body.push(Instruction::access_chain(
            bool_pointer_type_id,
            access_idx,
            blank_intersection_id,
            &[idx_id],
        ));
        tri_block
            .body
            .push(Instruction::store(access_idx, front_face_id, None));
        function.consume(tri_block, Instruction::branch(merge_label_id));
        function.consume(merge_block, Instruction::branch(outer_merge_label_id));
        function.consume(outer_merge_block, Instruction::branch(final_label_id));

        let loaded_blank_intersection_id = self.id_gen.next();
        final_block.body.push(Instruction::load(
            intersection_type_id,
            loaded_blank_intersection_id,
            blank_intersection_id,
            None,
        ));
        function.consume(
            final_block,
            Instruction::return_value(loaded_blank_intersection_id),
        );

        function.to_words(&mut self.logical_layout.function_definitions);
        self.ray_query_functions.insert(
            LookupRayQueryFunction::GetIntersection {
                committed: is_committed,
            },
            func_id,
        );
        func_id
    }

    fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word {
        if let Some(&word) = self
            .ray_query_functions
            .get(&LookupRayQueryFunction::Initialize)
        {
            return word;
        }

        let ray_query_type_id = self.get_ray_query_pointer_id();
        let acceleration_structure_type_id =
            self.get_localtype_id(super::LocalType::AccelerationStructure);
        let ray_desc_type_id = self.get_handle_type_id(
            ir_module
                .special_types
                .ray_desc
                .expect("ray desc should be set if ray queries are being initialized"),
        );

        let u32_ty = self.get_u32_type_id();
        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

        let f32_type_id = self.get_f32_type_id();
        let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);

        let bool_type_id = self.get_bool_type_id();
        let bool_vec3_type_id = self.get_vec3_bool_type_id();

        let (func_id, mut function, arg_ids) = self.write_function_signature(
            &[
                ray_query_type_id,
                acceleration_structure_type_id,
                ray_desc_type_id,
                u32_ptr_ty,
                f32_ptr_ty,
            ],
            self.void_type,
        );

        let query_id = arg_ids[0];
        let acceleration_structure_id = arg_ids[1];
        let desc_id = arg_ids[2];
        let init_tracker_id = arg_ids[3];
        let t_max_tracker_id = arg_ids[4];

        let label_id = self.id_gen.next();
        let mut block = Block::new(label_id);

        let flag_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));

        //Note: composite extract indices and types must match `generate_ray_desc_type`
        let ray_flags_id = self.id_gen.next();
        block.body.push(Instruction::composite_extract(
            flag_type_id,
            ray_flags_id,
            desc_id,
            &[0],
        ));
        let cull_mask_id = self.id_gen.next();
        block.body.push(Instruction::composite_extract(
            flag_type_id,
            cull_mask_id,
            desc_id,
            &[1],
        ));

        let tmin_id = self.id_gen.next();
        block.body.push(Instruction::composite_extract(
            f32_type_id,
            tmin_id,
            desc_id,
            &[2],
        ));
        let tmax_id = self.id_gen.next();
        block.body.push(Instruction::composite_extract(
            f32_type_id,
            tmax_id,
            desc_id,
            &[3],
        ));
        block
            .body
            .push(Instruction::store(t_max_tracker_id, tmax_id, None));

        let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
            size: crate::VectorSize::Tri,
            scalar: crate::Scalar::F32,
        });
        let ray_origin_id = self.id_gen.next();
        block.body.push(Instruction::composite_extract(
            vector_type_id,
            ray_origin_id,
            desc_id,
            &[4],
        ));
        let ray_dir_id = self.id_gen.next();
        block.body.push(Instruction::composite_extract(
            vector_type_id,
            ray_dir_id,
            desc_id,
            &[5],
        ));

        let valid_id = self.ray_query_initialization_tracking.then(||{
            let tmin_le_tmax_id = self.id_gen.next();
            // Check both that tmin is less than or equal to tmax (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06350)
            // and implicitly that neither tmin or tmax are NaN (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06351)
            // because this checks if tmin and tmax are ordered too (i.e: not NaN).
            block.body.push(Instruction::binary(
                spirv::Op::FOrdLessThanEqual,
                bool_type_id,
                tmin_le_tmax_id,
                tmin_id,
                tmax_id,
            ));

            // Check that tmin is greater than or equal to 0 (and
            // therefore also tmax is too because it is greater than
            // or equal to tmin) (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06349).
            let tmin_ge_zero_id = self.id_gen.next();
            let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0));
            block.body.push(Instruction::binary(
                spirv::Op::FOrdGreaterThanEqual,
                bool_type_id,
                tmin_ge_zero_id,
                tmin_id,
                zero_id,
            ));

            // Check that ray origin is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
            let ray_origin_infinite_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::IsInf,
                bool_vec3_type_id,
                ray_origin_infinite_id,
                ray_origin_id,
            ));
            let any_ray_origin_infinite_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::Any,
                bool_type_id,
                any_ray_origin_infinite_id,
                ray_origin_infinite_id,
            ));

            let ray_origin_nan_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::IsNan,
                bool_vec3_type_id,
                ray_origin_nan_id,
                ray_origin_id,
            ));
            let any_ray_origin_nan_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::Any,
                bool_type_id,
                any_ray_origin_nan_id,
                ray_origin_nan_id,
            ));

            let ray_origin_not_finite_id = self.id_gen.next();
            block.body.push(Instruction::binary(
                spirv::Op::LogicalOr,
                bool_type_id,
                ray_origin_not_finite_id,
                any_ray_origin_nan_id,
                any_ray_origin_infinite_id,
            ));

            let all_ray_origin_finite_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::LogicalNot,
                bool_type_id,
                all_ray_origin_finite_id,
                ray_origin_not_finite_id,
            ));

            // Check that ray direction is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
            let ray_dir_infinite_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::IsInf,
                bool_vec3_type_id,
                ray_dir_infinite_id,
                ray_dir_id,
            ));
            let any_ray_dir_infinite_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::Any,
                bool_type_id,
                any_ray_dir_infinite_id,
                ray_dir_infinite_id,
            ));

            let ray_dir_nan_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::IsNan,
                bool_vec3_type_id,
                ray_dir_nan_id,
                ray_dir_id,
            ));
            let any_ray_dir_nan_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::Any,
                bool_type_id,
                any_ray_dir_nan_id,
                ray_dir_nan_id,
            ));

            let ray_dir_not_finite_id = self.id_gen.next();
            block.body.push(Instruction::binary(
                spirv::Op::LogicalOr,
                bool_type_id,
                ray_dir_not_finite_id,
                any_ray_dir_nan_id,
                any_ray_dir_infinite_id,
            ));

            let all_ray_dir_finite_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::LogicalNot,
                bool_type_id,
                all_ray_dir_finite_id,
                ray_dir_not_finite_id,
            ));

            /// Writes spirv to check that less than two booleans are true
            ///
            /// For each boolean: removes it, `and`s it with all others (i.e for all possible combinations of two booleans in the list checks to see if both are true).
            /// Then `or`s all of these checks together. This produces whether two or more booleans are true.
            fn write_less_than_2_true(
                writer: &mut Writer,
                block: &mut Block,
                mut bools: Vec<spirv::Word>,
            ) -> spirv::Word {
                assert!(bools.len() > 1, "Must have multiple booleans!");
                let bool_ty = writer.get_bool_type_id();
                let mut each_two_true = Vec::new();
                while let Some(last_bool) = bools.pop() {
                    for &bool in &bools {
                        let both_true_id = writer.write_logical_and(
                            block,
                            last_bool,
                            bool,
                        );
                        each_two_true.push(both_true_id);
                    }
                }
                let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true`");
                for two_true in each_two_true {
                    let new_all_or_id = writer.id_gen.next();
                    block.body.push(Instruction::binary(
                        spirv::Op::LogicalOr,
                        bool_ty,
                        new_all_or_id,
                        all_or_id,
                        two_true,
                    ));
                    all_or_id = new_all_or_id;
                }

                let less_than_two_id = writer.id_gen.next();
                block.body.push(Instruction::unary(
                    spirv::Op::LogicalNot,
                    bool_ty,
                    less_than_two_id,
                    all_or_id,
                ));
                less_than_two_id
            }

            // Check that at most one of skip triangles and skip AABBs is
            // present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06889)
            let contains_skip_triangles = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::SKIP_TRIANGLES.bits(),
            );
            let contains_skip_aabbs = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::SKIP_AABBS.bits(),
            );

            let not_contain_skip_triangles_aabbs = write_less_than_2_true(
                self,
                &mut block,
                vec![contains_skip_triangles, contains_skip_aabbs],
            );

            // Check that at most one of skip triangles (taken from above check),
            // cull back facing, and cull front face is present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06890)
            let contains_cull_back = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::CULL_BACK_FACING.bits(),
            );
            let contains_cull_front = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::CULL_FRONT_FACING.bits(),
            );

            let not_contain_skip_triangles_cull = write_less_than_2_true(
                self,
                &mut block,
                vec![
                    contains_skip_triangles,
                    contains_cull_back,
                    contains_cull_front,
                ],
            );

            // Check that at most one of force opaque, force not opaque, cull opaque,
            // and cull not opaque are present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06891)
            let contains_opaque = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::FORCE_OPAQUE.bits(),
            );
            let contains_no_opaque = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::FORCE_NO_OPAQUE.bits(),
            );
            let contains_cull_opaque = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::CULL_OPAQUE.bits(),
            );
            let contains_cull_no_opaque = write_ray_flags_contains_flags(
                self,
                &mut block,
                ray_flags_id,
                crate::RayFlag::CULL_NO_OPAQUE.bits(),
            );

            let not_contain_multiple_opaque = write_less_than_2_true(
                self,
                &mut block,
                vec![
                    contains_opaque,
                    contains_no_opaque,
                    contains_cull_opaque,
                    contains_cull_no_opaque,
                ],
            );

            // Combine all checks into a single flag saying whether the call is valid or not.
            self.write_reduce_and(
                &mut block,
                vec![
                    tmin_le_tmax_id,
                    tmin_ge_zero_id,
                    all_ray_origin_finite_id,
                    all_ray_dir_finite_id,
                    not_contain_skip_triangles_aabbs,
                    not_contain_skip_triangles_cull,
                    not_contain_multiple_opaque,
                ],
            )
        });

        let merge_label_id = self.id_gen.next();
        let merge_block = Block::new(merge_label_id);

        // NOTE: this block will be unreachable if initialization tracking is disabled.
        let invalid_label_id = self.id_gen.next();
        let mut invalid_block = Block::new(invalid_label_id);

        let valid_label_id = self.id_gen.next();
        let mut valid_block = Block::new(valid_label_id);

        match valid_id {
            Some(all_valid_id) => {
                block.body.push(Instruction::selection_merge(
                    merge_label_id,
                    spirv::SelectionControl::NONE,
                ));
                function.consume(
                    block,
                    Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
                );
            }
            None => {
                function.consume(block, Instruction::branch(valid_label_id));
            }
        }

        valid_block.body.push(Instruction::ray_query_initialize(
            query_id,
            acceleration_structure_id,
            ray_flags_id,
            cull_mask_id,
            ray_origin_id,
            tmin_id,
            ray_dir_id,
            tmax_id,
        ));

        let const_initialized = self.get_constant_scalar(crate::Literal::U32(
            super::RayQueryPoint::INITIALIZED.bits(),
        ));
        valid_block
            .body
            .push(Instruction::store(init_tracker_id, const_initialized, None));

        function.consume(valid_block, Instruction::branch(merge_label_id));

        if self
            .flags
            .contains(super::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
        {
            self.write_debug_printf(
                &mut invalid_block,
                "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
                &[
                    ray_flags_id,
                    tmin_id,
                    tmax_id,
                    ray_origin_id,
                    ray_dir_id,
                ],
            );
        }

        function.consume(invalid_block, Instruction::branch(merge_label_id));

        function.consume(merge_block, Instruction::return_void());

        function.to_words(&mut self.logical_layout.function_definitions);

        self.ray_query_functions
            .insert(LookupRayQueryFunction::Initialize, func_id);
        func_id
    }

    fn write_ray_query_proceed(&mut self) -> spirv::Word {
        if let Some(&word) = self
            .ray_query_functions
            .get(&LookupRayQueryFunction::Proceed)
        {
            return word;
        }

        let ray_query_type_id = self.get_ray_query_pointer_id();

        let u32_ty = self.get_u32_type_id();
        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

        let bool_type_id = self.get_bool_type_id();
        let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);

        let (func_id, mut function, arg_ids) =
            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id);

        let query_id = arg_ids[0];
        let init_tracker_id = arg_ids[1];

        let block_id = self.id_gen.next();
        let mut block = Block::new(block_id);

        // TODO: perhaps this could be replaced with an OpPhi?
        let proceeded_id = self.id_gen.next();
        let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
        block.body.push(Instruction::variable(
            bool_ptr_ty,
            proceeded_id,
            spirv::StorageClass::Function,
            Some(const_false),
        ));

        let initialized_tracker_id = self.id_gen.next();
        block.body.push(Instruction::load(
            u32_ty,
            initialized_tracker_id,
            init_tracker_id,
            None,
        ));

        let merge_id = self.id_gen.next();
        let mut merge_block = Block::new(merge_id);

        let valid_block_id = self.id_gen.next();
        let mut valid_block = Block::new(valid_block_id);

        let instruction = if self.ray_query_initialization_tracking {
            let is_initialized = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                super::RayQueryPoint::INITIALIZED.bits(),
            );

            block.body.push(Instruction::selection_merge(
                merge_id,
                spirv::SelectionControl::NONE,
            ));

            Instruction::branch_conditional(is_initialized, valid_block_id, merge_id)
        } else {
            Instruction::branch(valid_block_id)
        };

        function.consume(block, instruction);

        let has_proceeded = self.id_gen.next();
        valid_block.body.push(Instruction::ray_query_proceed(
            bool_type_id,
            has_proceeded,
            query_id,
        ));

        valid_block
            .body
            .push(Instruction::store(proceeded_id, has_proceeded, None));

        let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
            (super::RayQueryPoint::PROCEED | super::RayQueryPoint::FINISHED_TRAVERSAL).bits(),
        ));
        let add_flag_continuing =
            self.get_constant_scalar(crate::Literal::U32(super::RayQueryPoint::PROCEED.bits()));

        let add_flags_id = self.id_gen.next();
        valid_block.body.push(Instruction::select(
            u32_ty,
            add_flags_id,
            has_proceeded,
            add_flag_continuing,
            add_flag_finished,
        ));
        let final_flags = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::BitwiseOr,
            u32_ty,
            final_flags,
            initialized_tracker_id,
            add_flags_id,
        ));
        valid_block
            .body
            .push(Instruction::store(init_tracker_id, final_flags, None));

        function.consume(valid_block, Instruction::branch(merge_id));

        let loaded_proceeded_id = self.id_gen.next();
        merge_block.body.push(Instruction::load(
            bool_type_id,
            loaded_proceeded_id,
            proceeded_id,
            None,
        ));

        function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));

        function.to_words(&mut self.logical_layout.function_definitions);

        self.ray_query_functions
            .insert(LookupRayQueryFunction::Proceed, func_id);
        func_id
    }

    fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
        if let Some(&word) = self
            .ray_query_functions
            .get(&LookupRayQueryFunction::GenerateIntersection)
        {
            return word;
        }

        let ray_query_type_id = self.get_ray_query_pointer_id();

        let u32_ty = self.get_u32_type_id();
        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

        let f32_type_id = self.get_f32_type_id();
        let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);

        let bool_type_id = self.get_bool_type_id();

        let (func_id, mut function, arg_ids) = self.write_function_signature(
            &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id],
            self.void_type,
        );

        let query_id = arg_ids[0];
        let init_tracker_id = arg_ids[1];
        let depth_id = arg_ids[2];
        let t_max_tracker_id = arg_ids[3];

        let block_id = self.id_gen.next();
        let mut block = Block::new(block_id);

        let current_t = self.id_gen.next();
        block.body.push(Instruction::variable(
            f32_ptr_type_id,
            current_t,
            spirv::StorageClass::Function,
            None,
        ));

        let current_t = self.id_gen.next();
        block.body.push(Instruction::variable(
            f32_ptr_type_id,
            current_t,
            spirv::StorageClass::Function,
            None,
        ));

        let valid_id = self.id_gen.next();
        let mut valid_block = Block::new(valid_id);

        let final_label_id = self.id_gen.next();
        let final_block = Block::new(final_label_id);

        let instruction = if self.ray_query_initialization_tracking {
            let initialized_tracker_id = self.id_gen.next();
            block.body.push(Instruction::load(
                u32_ty,
                initialized_tracker_id,
                init_tracker_id,
                None,
            ));

            let proceeded_id = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                super::RayQueryPoint::PROCEED.bits(),
            );
            let finished_proceed_id = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
            );

            // Can't find anything to suggest double calling this function is invalid.

            let not_finished_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::LogicalNot,
                bool_type_id,
                not_finished_id,
                finished_proceed_id,
            ));

            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);

            block.body.push(Instruction::selection_merge(
                final_label_id,
                spirv::SelectionControl::NONE,
            ));

            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
        } else {
            Instruction::branch(valid_id)
        };

        function.consume(block, instruction);

        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
        ));
        let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32(
            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
        ));
        let raw_kind_id = self.id_gen.next();
        valid_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionTypeKHR,
                u32_ty,
                raw_kind_id,
                query_id,
                intersection_id,
            ));

        let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
        ));
        let intersection_aabb_id = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            intersection_aabb_id,
            raw_kind_id,
            candidate_aabb_id,
        ));

        // Check that the provided t value is between t min and the current committed
        // t value, (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryGenerateIntersectionKHR-06353)

        // Get the tmin
        let t_min_id = self.id_gen.next();
        valid_block.body.push(Instruction::ray_query_get_t_min(
            f32_type_id,
            t_min_id,
            query_id,
        ));

        // Get the current committed t, or tmax if no hit.
        // Basically emulate HLSL's (easier) version
        // Pseudo-code:
        // ````wgsl
        // // start of function
        // var current_t:f32;
        // ...
        // let committed_type_id = RayQueryGetIntersectionTypeKHR<Committed>(query_id);
        // if committed_type_id == Committed_None {
        //     current_t = load(t_max_tracker);
        // } else {
        //     current_t = RayQueryGetIntersectionTKHR<Committed>(query_id);
        // }
        // ...
        // ````

        let committed_type_id = self.id_gen.next();
        valid_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionTypeKHR,
                u32_ty,
                committed_type_id,
                query_id,
                committed_intersection_id,
            ));

        let no_committed = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            no_committed,
            committed_type_id,
            self.get_constant_scalar(crate::Literal::U32(
                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _,
            )),
        ));

        let next_valid_block_id = self.id_gen.next();
        let no_committed_block_id = self.id_gen.next();
        let mut no_committed_block = Block::new(no_committed_block_id);
        let committed_block_id = self.id_gen.next();
        let mut committed_block = Block::new(committed_block_id);
        valid_block.body.push(Instruction::selection_merge(
            next_valid_block_id,
            spirv::SelectionControl::NONE,
        ));
        function.consume(
            valid_block,
            Instruction::branch_conditional(
                no_committed,
                no_committed_block_id,
                committed_block_id,
            ),
        );

        // Assign t_max to current_t
        let t_max_id = self.id_gen.next();
        no_committed_block.body.push(Instruction::load(
            f32_type_id,
            t_max_id,
            t_max_tracker_id,
            None,
        ));
        no_committed_block
            .body
            .push(Instruction::store(current_t, t_max_id, None));
        function.consume(no_committed_block, Instruction::branch(next_valid_block_id));

        // Assign t_current to current_t
        let latest_t_id = self.id_gen.next();
        committed_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionTKHR,
                f32_type_id,
                latest_t_id,
                query_id,
                intersection_id,
            ));
        committed_block
            .body
            .push(Instruction::store(current_t, latest_t_id, None));
        function.consume(committed_block, Instruction::branch(next_valid_block_id));

        let mut valid_block = Block::new(next_valid_block_id);

        let t_ge_t_min = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::FOrdGreaterThanEqual,
            bool_type_id,
            t_ge_t_min,
            depth_id,
            t_min_id,
        ));
        let t_current = self.id_gen.next();
        valid_block
            .body
            .push(Instruction::load(f32_type_id, t_current, current_t, None));
        let t_le_t_current = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::FOrdLessThanEqual,
            bool_type_id,
            t_le_t_current,
            depth_id,
            t_current,
        ));

        let t_in_range = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::LogicalAnd,
            bool_type_id,
            t_in_range,
            t_ge_t_min,
            t_le_t_current,
        ));

        let call_valid_id = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::LogicalAnd,
            bool_type_id,
            call_valid_id,
            t_in_range,
            intersection_aabb_id,
        ));

        let generate_label_id = self.id_gen.next();
        let mut generate_block = Block::new(generate_label_id);

        let merge_label_id = self.id_gen.next();
        let merge_block = Block::new(merge_label_id);

        valid_block.body.push(Instruction::selection_merge(
            merge_label_id,
            spirv::SelectionControl::NONE,
        ));
        function.consume(
            valid_block,
            Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id),
        );

        generate_block
            .body
            .push(Instruction::ray_query_generate_intersection(
                query_id, depth_id,
            ));

        function.consume(generate_block, Instruction::branch(merge_label_id));
        function.consume(merge_block, Instruction::branch(final_label_id));

        function.consume(final_block, Instruction::return_void());

        function.to_words(&mut self.logical_layout.function_definitions);

        self.ray_query_functions
            .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
        func_id
    }

    fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
        if let Some(&word) = self
            .ray_query_functions
            .get(&LookupRayQueryFunction::ConfirmIntersection)
        {
            return word;
        }

        let ray_query_type_id = self.get_ray_query_pointer_id();

        let u32_ty = self.get_u32_type_id();
        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

        let bool_type_id = self.get_bool_type_id();

        let (func_id, mut function, arg_ids) =
            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);

        let query_id = arg_ids[0];
        let init_tracker_id = arg_ids[1];

        let block_id = self.id_gen.next();
        let mut block = Block::new(block_id);

        let valid_id = self.id_gen.next();
        let mut valid_block = Block::new(valid_id);

        let final_label_id = self.id_gen.next();
        let final_block = Block::new(final_label_id);

        let instruction = if self.ray_query_initialization_tracking {
            let initialized_tracker_id = self.id_gen.next();
            block.body.push(Instruction::load(
                u32_ty,
                initialized_tracker_id,
                init_tracker_id,
                None,
            ));

            let proceeded_id = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                super::RayQueryPoint::PROCEED.bits(),
            );
            let finished_proceed_id = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
            );
            // Although it seems strange to call this twice, I (Vecvec) can't find anything to suggest double calling this function is invalid.
            let not_finished_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::LogicalNot,
                bool_type_id,
                not_finished_id,
                finished_proceed_id,
            ));

            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);

            block.body.push(Instruction::selection_merge(
                final_label_id,
                spirv::SelectionControl::NONE,
            ));

            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
        } else {
            Instruction::branch(valid_id)
        };

        function.consume(block, instruction);

        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
        ));
        let raw_kind_id = self.id_gen.next();
        valid_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionTypeKHR,
                u32_ty,
                raw_kind_id,
                query_id,
                intersection_id,
            ));

        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
        ));
        let intersection_tri_id = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            intersection_tri_id,
            raw_kind_id,
            candidate_tri_id,
        ));

        let generate_label_id = self.id_gen.next();
        let mut generate_block = Block::new(generate_label_id);

        let merge_label_id = self.id_gen.next();
        let merge_block = Block::new(merge_label_id);

        valid_block.body.push(Instruction::selection_merge(
            merge_label_id,
            spirv::SelectionControl::NONE,
        ));
        function.consume(
            valid_block,
            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
        );

        generate_block
            .body
            .push(Instruction::ray_query_confirm_intersection(query_id));

        function.consume(generate_block, Instruction::branch(merge_label_id));
        function.consume(merge_block, Instruction::branch(final_label_id));

        function.consume(final_block, Instruction::return_void());

        self.ray_query_functions
            .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);

        function.to_words(&mut self.logical_layout.function_definitions);

        func_id
    }

    fn write_ray_query_get_vertex_positions(
        &mut self,
        is_committed: bool,
        ir_module: &crate::Module,
    ) -> spirv::Word {
        if let Some(&word) =
            self.ray_query_functions
                .get(&LookupRayQueryFunction::GetVertexPositions {
                    committed: is_committed,
                })
        {
            return word;
        }

        let (committed_ty, committed_tri_ty) = if is_committed {
            (
                spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
                    as u32,
            )
        } else {
            (
                spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
                    as u32,
            )
        };

        let ray_query_type_id = self.get_ray_query_pointer_id();

        let u32_ty = self.get_u32_type_id();
        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
            *ir_module
                .special_types
                .ray_vertex_return
                .as_ref()
                .expect("must be generated when reading in get vertex position"),
        );
        let ptr_return_ty =
            self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);

        let bool_type_id = self.get_bool_type_id();

        let (func_id, mut function, arg_ids) = self.write_function_signature(
            &[ray_query_type_id, u32_ptr_ty],
            rq_get_vertex_positions_ty_id,
        );

        let query_id = arg_ids[0];
        let init_tracker_id = arg_ids[1];

        let block_id = self.id_gen.next();
        let mut block = Block::new(block_id);

        let return_id = self.id_gen.next();
        block.body.push(Instruction::variable(
            ptr_return_ty,
            return_id,
            spirv::StorageClass::Function,
            Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
        ));

        let valid_id = self.id_gen.next();
        let mut valid_block = Block::new(valid_id);

        let final_label_id = self.id_gen.next();
        let mut final_block = Block::new(final_label_id);

        let instruction = if self.ray_query_initialization_tracking {
            let initialized_tracker_id = self.id_gen.next();
            block.body.push(Instruction::load(
                u32_ty,
                initialized_tracker_id,
                init_tracker_id,
                None,
            ));

            let proceeded_id = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                super::RayQueryPoint::PROCEED.bits(),
            );
            let finished_proceed_id = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
            );

            let correct_finish_id = if is_committed {
                finished_proceed_id
            } else {
                let not_finished_id = self.id_gen.next();
                block.body.push(Instruction::unary(
                    spirv::Op::LogicalNot,
                    bool_type_id,
                    not_finished_id,
                    finished_proceed_id,
                ));
                not_finished_id
            };

            let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id);
            block.body.push(Instruction::selection_merge(
                final_label_id,
                spirv::SelectionControl::NONE,
            ));
            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
        } else {
            Instruction::branch(valid_id)
        };

        function.consume(block, instruction);

        let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
        let raw_kind_id = self.id_gen.next();
        valid_block
            .body
            .push(Instruction::ray_query_get_intersection(
                spirv::Op::RayQueryGetIntersectionTypeKHR,
                u32_ty,
                raw_kind_id,
                query_id,
                intersection_id,
            ));

        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
        let intersection_tri_id = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            intersection_tri_id,
            raw_kind_id,
            candidate_tri_id,
        ));

        let generate_label_id = self.id_gen.next();
        let mut vertex_return_block = Block::new(generate_label_id);

        let merge_label_id = self.id_gen.next();
        let merge_block = Block::new(merge_label_id);

        valid_block.body.push(Instruction::selection_merge(
            merge_label_id,
            spirv::SelectionControl::NONE,
        ));
        function.consume(
            valid_block,
            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
        );

        let vertices_id = self.id_gen.next();
        vertex_return_block
            .body
            .push(Instruction::ray_query_return_vertex_position(
                rq_get_vertex_positions_ty_id,
                vertices_id,
                query_id,
                intersection_id,
            ));
        vertex_return_block
            .body
            .push(Instruction::store(return_id, vertices_id, None));

        function.consume(vertex_return_block, Instruction::branch(merge_label_id));
        function.consume(merge_block, Instruction::branch(final_label_id));

        let loaded_pos_id = self.id_gen.next();
        final_block.body.push(Instruction::load(
            rq_get_vertex_positions_ty_id,
            loaded_pos_id,
            return_id,
            None,
        ));

        function.consume(final_block, Instruction::return_value(loaded_pos_id));

        self.ray_query_functions.insert(
            LookupRayQueryFunction::GetVertexPositions {
                committed: is_committed,
            },
            func_id,
        );

        function.to_words(&mut self.logical_layout.function_definitions);

        func_id
    }
}

impl BlockContext<'_> {
    pub(super) fn write_ray_query_function(
        &mut self,
        query: Handle<crate::Expression>,
        function: &crate::RayQueryFunction,
        block: &mut Block,
    ) {
        let query_id = self.cached[query];
        let tracker_ids = *self
            .ray_query_tracker_expr
            .get(&query)
            .expect("not a cached ray query");

        match *function {
            crate::RayQueryFunction::Initialize {
                acceleration_structure,
                descriptor,
            } => {
                let desc_id = self.cached[descriptor];
                let acc_struct_id = self.get_handle_id(acceleration_structure);

                let func = self.writer.write_ray_query_initialize(self.ir_module);

                let func_id = self.gen_id();
                block.body.push(Instruction::function_call(
                    self.writer.void_type,
                    func_id,
                    func,
                    &[
                        query_id,
                        acc_struct_id,
                        desc_id,
                        tracker_ids.initialized_tracker,
                        tracker_ids.t_max_tracker,
                    ],
                ));
            }
            crate::RayQueryFunction::Proceed { result } => {
                let id = self.gen_id();
                self.cached[result] = id;

                let bool_ty = self.writer.get_bool_type_id();

                let func_id = self.writer.write_ray_query_proceed();
                block.body.push(Instruction::function_call(
                    bool_ty,
                    id,
                    func_id,
                    &[query_id, tracker_ids.initialized_tracker],
                ));
            }
            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                let hit_id = self.cached[hit_t];

                let func_id = self.writer.write_ray_query_generate_intersection();

                let func_call_id = self.gen_id();
                block.body.push(Instruction::function_call(
                    self.writer.void_type,
                    func_call_id,
                    func_id,
                    &[
                        query_id,
                        tracker_ids.initialized_tracker,
                        hit_id,
                        tracker_ids.t_max_tracker,
                    ],
                ));
            }
            crate::RayQueryFunction::ConfirmIntersection => {
                let func_id = self.writer.write_ray_query_confirm_intersection();

                let func_call_id = self.gen_id();
                block.body.push(Instruction::function_call(
                    self.writer.void_type,
                    func_call_id,
                    func_id,
                    &[query_id, tracker_ids.initialized_tracker],
                ));
            }
            crate::RayQueryFunction::Terminate => {}
        }
    }

    pub(super) fn write_ray_query_return_vertex_position(
        &mut self,
        query: Handle<crate::Expression>,
        block: &mut Block,
        is_committed: bool,
    ) -> spirv::Word {
        let fn_id = self
            .writer
            .write_ray_query_get_vertex_positions(is_committed, self.ir_module);

        let query_id = self.cached[query];
        let tracker_id = *self
            .ray_query_tracker_expr
            .get(&query)
            .expect("not a cached ray query");

        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
            *self
                .ir_module
                .special_types
                .ray_vertex_return
                .as_ref()
                .expect("must be generated when reading in get vertex position"),
        );

        let func_call_id = self.gen_id();
        block.body.push(Instruction::function_call(
            rq_get_vertex_positions_ty_id,
            func_call_id,
            fn_id,
            &[query_id, tracker_id.initialized_tracker],
        ));
        func_call_id
    }
}
